{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "41e39522-337e-477e-aeef-01766f2b9911",
   "metadata": {},
   "source": [
    "# ch3 - Using embedding models"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c8f6f95a-960d-4ddd-bab2-2b6326e862e6",
   "metadata": {
    "tags": []
   },
   "source": [
    "## 1. Hands-on: calling an embedding model in different ways and basic embedding operations\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c1d0617-a55c-44d5-a9a0-a9bd4e7a5ccb",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Contents\n",
    "- <font size=5>Three ways to call an embedding model</font>: transformers, sentence_transformers, and langchain\n",
    "\n",
    "``` shell\n",
    "pip install transformers\n",
    "pip install sentence_transformers\n",
    "pip install langchain\n",
    "```\n",
    "\n",
    "- <font size=5>Two basic embedding operations</font>: similarity computation and clustering"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3edc7a14-af3c-4400-ae77-b18c6351b3da",
   "metadata": {
    "tags": []
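   },
   "source": [
    "## 2. The transformers approach\n",
    "- Transformers is a widely used Python library developed by Hugging Face for natural language processing (NLP) tasks; it is best known for its implementations of Transformer-architecture models\n",
    "- Embedding models such as the one used here are also built on the Transformer architecture\n",
    "- They can therefore be loaded and called through the transformers library as well"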
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "69868ff4-b10f-412d-8024-65c85e507ecd",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, AutoModel\n",
    "from sentence_transformers.util import cos_sim\n",
    "import torch.nn.functional as F"
   ]
  },
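  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "fd05a88d-9143-49ca-8d1a-e2e7ee57c3c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example sentences stay in Chinese because gte-large-zh is a Chinese embedding model\n",
    "input_texts = [\n",
    "    \"中国的首都是哪里\",  # Where is the capital of China?\n",
    "    \"你喜欢去哪里旅游\",  # Where do you like to travel?\n",
    "    \"北京\",  # Beijing\n",
    "    \"今天中午吃什么\"  # What should we eat for lunch today?\n",
    "]"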
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "94322058-9173-4b4d-b028-8f7fe7975e7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_path = './data/llm_app/embedding_models/gte-large-zh/'\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_path)\n",
    "model = AutoModel.from_pretrained(model_path, device_map='cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "e8503b66-f14b-405e-9d8d-e14d180fe403",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_tokens = tokenizer(input_texts,\n",
    "                        max_length=30,\n",
    "                        padding=True,\n",
    "                        truncation=True,\n",
    "                        return_tensors='pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "7b9bcc6c-e0e8-409e-8209-8c1998f16298",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['[CLS]', '中', '国', '的', '首', '都', '是', '哪', '里', '[SEP]']"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch_tokens[0].tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "585e8a8d-4aec-44d6-b0a2-24d8397df0eb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['[CLS]', '北', '京', '[SEP]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']\n"
     ]
    }
   ],
   "source": [
    "print(batch_tokens[2].tokens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "3eb62cb9-8b40-4e06-bdfe-3eaa1d2f7ba1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 101,  704, 1744, 4638, 7674, 6963, 3221, 1525, 7027,  102])"
      ]
     },
     "execution_count": 52,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch_tokens.input_ids[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "8b967ffe-a306-478b-9b7d-c198e4ee2dc7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 101, 1266,  776,  102,    0,    0,    0,    0,    0,    0])"
      ]
     },
     "execution_count": 53,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch_tokens.input_ids[2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "592285fe-3def-4e0c-aea5-2ec8a2a8ea62",
   "metadata": {},
   "outputs": [],
   "source": [
    "outputs = model(**batch_tokens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "id": "933e57b2-6920-46e9-87d2-4ef3b3a2e8d3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=tensor([[[ 2.0181,  0.4087,  0.1180,  ...,  1.0200,  0.4325, -0.5421],\n",
       "         [ 1.5130,  0.2715, -0.0208,  ...,  0.5583,  0.4469, -0.1568],\n",
       "         [ 2.0624, -0.1686,  0.0687,  ...,  0.9885,  0.5314, -0.1303],\n",
       "         ...,\n",
       "         [ 1.4790,  0.3036, -0.0824,  ...,  0.8273,  0.6185, -0.1685],\n",
       "         [ 1.6145,  0.5669,  0.0285,  ...,  1.0715,  0.6865, -0.3383],\n",
       "         [ 2.0181,  0.4085,  0.1180,  ...,  1.0200,  0.4326, -0.5420]],\n",
       "\n",
       "        [[ 0.4519, -1.1412,  0.0964,  ...,  0.2385,  0.5646, -0.9434],\n",
       "         [ 0.0185, -0.7171,  0.1632,  ...,  0.7637, -0.1800, -0.2161],\n",
       "         [-0.3061, -0.8769,  0.2718,  ...,  0.6523,  0.3814, -0.8347],\n",
       "         ...,\n",
       "         [ 0.2691, -1.1795,  0.2144,  ...,  0.2852,  0.1001, -0.7169],\n",
       "         [ 0.2553, -1.0392,  0.0997,  ...,  0.6434,  0.3948, -0.6787],\n",
       "         [ 0.4520, -1.1413,  0.0967,  ...,  0.2389,  0.5644, -0.9432]],\n",
       "\n",
       "        [[ 0.6939,  0.8812, -0.5522,  ...,  1.1094, -0.7641, -0.8137],\n",
       "         [ 0.4705,  0.1928, -0.2424,  ...,  0.9109, -0.5743, -0.8637],\n",
       "         [ 0.5483,  0.2564, -0.1667,  ...,  0.9154, -0.4747, -0.9967],\n",
       "         ...,\n",
       "         [ 0.4414,  0.5646, -0.4097,  ...,  1.0151, -0.1386, -0.7399],\n",
       "         [ 0.3782,  0.5225, -0.4428,  ...,  1.1022, -0.2059, -0.7586],\n",
       "         [ 0.5822,  0.7111, -0.7050,  ...,  1.0280, -0.0522, -0.8880]],\n",
       "\n",
       "        [[ 1.3131, -0.3945, -0.9939,  ..., -0.7322,  0.9566, -1.5303],\n",
       "         [-0.0722, -0.2521, -0.3308,  ..., -0.6924,  0.9003, -1.1423],\n",
       "         [ 0.4460, -0.4178, -0.6173,  ..., -1.0655,  0.5030, -0.9752],\n",
       "         ...,\n",
       "         [ 0.1558, -0.5395, -0.5539,  ..., -0.8361,  0.8982, -0.9885],\n",
       "         [ 1.3125, -0.3946, -0.9941,  ..., -0.7323,  0.9566, -1.5303],\n",
       "         [ 0.7531, -0.4663, -1.2458,  ..., -0.4931,  0.9329, -1.2392]]],\n",
       "       grad_fn=<NativeLayerNormBackward0>), pooler_output=tensor([[-0.6892, -0.6961, -0.1464,  ...,  0.0733, -0.0624, -0.7050],\n",
       "        [-0.1941,  0.1064,  0.1948,  ..., -0.5334,  0.1855, -0.0094],\n",
       "        [-0.3968, -0.7202,  0.1894,  ...,  0.4823, -0.0780, -0.8869],\n",
       "        [-0.3021, -0.6326, -0.1963,  ..., -0.4366, -0.2662,  0.1489]],\n",
       "       grad_fn=<TanhBackward0>), hidden_states=None, past_key_values=None, attentions=None, cross_attentions=None)"
      ]
     },
     "execution_count": 58,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "id": "64f20cd5-df20-47dd-a86e-2628ac284945",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 10, 1024])"
      ]
     },
     "execution_count": 59,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "outputs.last_hidden_state.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "id": "bbbda7c1-2d9c-4d6c-b721-c5f28549a60f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 1024])"
      ]
     },
     "execution_count": 60,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "outputs.pooler_output.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "id": "0ec2dd40-e177-4c0e-baea-8524ee03662d",
   "metadata": {},
   "outputs": [],
   "source": [
    "embeddings = outputs.last_hidden_state[:, 0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "id": "62cb4369-97cb-41db-a105-03e81d34019c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 1024])"
      ]
     },
     "execution_count": 63,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "embeddings.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "id": "a5907c15-b1a5-4ef2-b858-090b47424d57",
   "metadata": {},
   "outputs": [],
   "source": [
    "embeddings = F.normalize(embeddings, p=2, dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "id": "4e011b12-6093-435c-8fe8-043634279079",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "中国的首都是哪里 你喜欢去哪里旅游 tensor([[0.3295]], grad_fn=<MmBackward0>)\n",
      "中国的首都是哪里 北京 tensor([[0.6354]], grad_fn=<MmBackward0>)\n",
      "中国的首都是哪里 今天中午吃什么 tensor([[0.3248]], grad_fn=<MmBackward0>)\n"
     ]
    }
   ],
   "source": [
    "for i in range(1, 4):\n",
    "    print(input_texts[0], input_texts[i], cos_sim(embeddings[0], embeddings[i]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a20eda25-4006-4e2d-85fe-e9d11f66c56d",
   "metadata": {
    "tags": []
   },
   "source": [
    "## 3. The sentence_transformers approach\n",
    "- Sentence-Transformers is a Python library built on PyTorch and Transformers, designed for computing sentence, text, and image embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "id": "64731ed2-8c04-4960-b5d3-5f812c14b4cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sentence_transformers import SentenceTransformer\n",
    "from sentence_transformers.util import cos_sim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "id": "38cb9879-f9bb-48d8-a3d8-02005b373b61",
   "metadata": {},
   "outputs": [],
   "source": [
    "input_texts = [\n",
    "    \"中国的首都是哪里\",\n",
    "    \"你喜欢去哪里旅游\",\n",
    "    \"北京\",\n",
    "    \"今天中午吃什么\"\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "id": "9958d6cd-64e0-4b9a-8ec4-fccd27034a85",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_path = './data/llm_app/embedding_models/gte-large-zh/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "id": "19afa4c2-48b1-4e09-b37d-72a6b5f3f525",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = SentenceTransformer(model_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "id": "931c717d-11f3-441d-a9c0-b2a3aa4486e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "embeddings = model.encode(input_texts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "id": "970aea44-f904-451b-95b8-0ce037bd5c3b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(4, 1024)"
      ]
     },
     "execution_count": 72,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "embeddings.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "id": "5b121cb5-568a-4175-b995-5c145daba837",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "中国的首都是哪里 你喜欢去哪里旅游 tensor([[0.3295]])\n",
      "中国的首都是哪里 北京 tensor([[0.6354]])\n",
      "中国的首都是哪里 今天中午吃什么 tensor([[0.3248]])\n"
     ]
    }
   ],
   "source": [
    "for i in range(1, 4):\n",
    "    print(input_texts[0], input_texts[i], cos_sim(embeddings[0], embeddings[i]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "90579ba5-6c27-4993-91be-bf2d8f627980",
   "metadata": {
    "tags": []
   },
   "source": [
    "## 4. The langchain approach\n",
    "- <font size=5>A wrapper around SentenceTransformer</font>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "id": "04be2ce2-5cf4-4eee-be79-62d2ed69b028",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.embeddings.huggingface import HuggingFaceEmbeddings\n",
    "from sentence_transformers.util import cos_sim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "id": "a36133bf-265b-4c8f-826d-29681e090a08",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_path = './data/llm_app/embedding_models/gte-large-zh/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "id": "cc23e18b-b025-4396-af9f-f0b3746f8142",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = HuggingFaceEmbeddings(model_name=model_path,\n",
    "                             model_kwargs={\"device\": \"cpu\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "id": "de4ed641-3d2c-415e-824a-7f3488b34185",
   "metadata": {},
   "outputs": [],
   "source": [
    "embeddings = model.embed_documents(input_texts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "id": "55ad5a1a-d6d5-4322-a81f-c5296444b5c4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(4, 1024)\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "embeddings = np.array(embeddings)\n",
    "print(embeddings.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "id": "8148a47d-1243-4c36-8660-6ed07bbd0c84",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "中国的首都是哪里 你喜欢去哪里旅游 tensor([[0.3295]], dtype=torch.float64)\n",
      "中国的首都是哪里 北京 tensor([[0.6354]], dtype=torch.float64)\n",
      "中国的首都是哪里 今天中午吃什么 tensor([[0.3248]], dtype=torch.float64)\n"
     ]
    }
   ],
   "source": [
    "for i in range(1, 4):\n",
    "    print(input_texts[0], input_texts[i], cos_sim(embeddings[0], embeddings[i]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "705a7375-aeab-420c-86d3-6e52813a2464",
   "metadata": {
    "tags": []
   },
   "source": [
    "## 5. Embedding operations"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2d6f8bf9-630b-48f3-9297-34e00ebbc6a9",
   "metadata": {
    "tags": []
   },
   "source": [
    "### 5.1 Distance computation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8558b7ba-9658-4b17-9e58-287cf40a0633",
   "metadata": {},
   "source": [
    "- <font size=5>余弦相似度</font>\n",
    "\n",
    "![](./data/cos.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "id": "f7669b72-cbb3-404a-a9fe-4094c490a9a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "a = embeddings[0]\n",
    "b = embeddings[2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "id": "79b21b7e-6c7e-4648-851f-86a95bb402ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "from numpy import dot\n",
    "from numpy.linalg import norm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "id": "8bc18e79-12c1-4187-8d22-38cf977a90d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "cos_a_b = dot(a, b) / (norm(a) * norm(b))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "id": "9bd10914-31f2-467d-8ac6-39cafc0f1f9a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.6353580110345967 tensor([[0.6354]], dtype=torch.float64)\n"
     ]
    }
   ],
   "source": [
    "print(cos_a_b, cos_sim(a, b))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "00853666-5979-418c-a89b-cd90f6271eb9",
   "metadata": {},
   "source": [
    "- ><font size=5>欧几里得距离</font>\n",
    "\n",
    "![](./data/l2.png)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "id": "0db9431a-a70a-4c23-8472-ff4dd805a873",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.853981260361025"
      ]
     },
     "execution_count": 87,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "norm(a - b)"
   ]
  },
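  {
   "cell_type": "markdown",
   "id": "gr-cos-l2-unit-norm-relation",
   "metadata": {},
   "source": [
    "The two measures above are closely linked. As a short worked derivation, assuming `a` and `b` both have unit L2 norm (an assumption; the values printed above are consistent with it), expanding the squared distance gives:\n",
    "\n",
    "$$\n",
    "\\lVert a - b \\rVert^2 = \\lVert a \\rVert^2 + \\lVert b \\rVert^2 - 2\\, a \\cdot b = 2 - 2\\cos(a, b)\n",
    "$$\n",
    "\n",
    "With $\\cos(a, b) \\approx 0.6354$ this gives $\\lVert a - b \\rVert \\approx \\sqrt{2 - 2 \\times 0.6354} \\approx 0.854$, in line with the `norm(a - b)` result. Under that assumption, ranking by cosine similarity and ranking by Euclidean distance produce the same order."
   ]
  },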
  {
   "cell_type": "markdown",
   "id": "8a7e8d75-9711-4234-b292-82ad7b28bceb",
   "metadata": {
    "tags": []
   },
   "source": [
    "### 5.2 Clustering"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "id": "b9dedced-87f1-46b7-ac15-f1e1a4298b9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "texts = ['苹果', '菠萝', '西瓜', '斑马', '大象', '老鼠']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "id": "63596b0c-e5d2-4d00-ba4f-012c32285b26",
   "metadata": {},
   "outputs": [],
   "source": [
    "output_embeddings = model.embed_documents(texts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "id": "af71bc1c-2e83-4e73-ae71-4f1b557576a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.cluster import KMeans"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "id": "106a1eaf-8178-471b-ad0f-240a38a92b1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "kmeans = KMeans(n_clusters=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 93,
   "id": "fa93ecdf-4a0b-43ed-9c6c-adcac32293b9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>#sk-container-id-2 {\n",
       "  /* Definition of color scheme common for light and dark mode */\n",
       "  --sklearn-color-text: black;\n",
       "  --sklearn-color-line: gray;\n",
       "  /* Definition of color scheme for unfitted estimators */\n",
       "  --sklearn-color-unfitted-level-0: #fff5e6;\n",
       "  --sklearn-color-unfitted-level-1: #f6e4d2;\n",
       "  --sklearn-color-unfitted-level-2: #ffe0b3;\n",
       "  --sklearn-color-unfitted-level-3: chocolate;\n",
       "  /* Definition of color scheme for fitted estimators */\n",
       "  --sklearn-color-fitted-level-0: #f0f8ff;\n",
       "  --sklearn-color-fitted-level-1: #d4ebff;\n",
       "  --sklearn-color-fitted-level-2: #b3dbfd;\n",
       "  --sklearn-color-fitted-level-3: cornflowerblue;\n",
       "\n",
       "  /* Specific color for light theme */\n",
       "  --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
       "  --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));\n",
       "  --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
       "  --sklearn-color-icon: #696969;\n",
       "\n",
       "  @media (prefers-color-scheme: dark) {\n",
       "    /* Redefinition of color scheme for dark theme */\n",
       "    --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
       "    --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));\n",
       "    --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
       "    --sklearn-color-icon: #878787;\n",
       "  }\n",
       "}\n",
       "\n",
       "#sk-container-id-2 {\n",
       "  color: var(--sklearn-color-text);\n",
       "}\n",
       "\n",
       "#sk-container-id-2 pre {\n",
       "  padding: 0;\n",
       "}\n",
       "\n",
       "#sk-container-id-2 input.sk-hidden--visually {\n",
       "  border: 0;\n",
       "  clip: rect(1px 1px 1px 1px);\n",
       "  clip: rect(1px, 1px, 1px, 1px);\n",
       "  height: 1px;\n",
       "  margin: -1px;\n",
       "  overflow: hidden;\n",
       "  padding: 0;\n",
       "  position: absolute;\n",
       "  width: 1px;\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-dashed-wrapped {\n",
       "  border: 1px dashed var(--sklearn-color-line);\n",
       "  margin: 0 0.4em 0.5em 0.4em;\n",
       "  box-sizing: border-box;\n",
       "  padding-bottom: 0.4em;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-container {\n",
       "  /* jupyter's `normalize.less` sets `[hidden] { display: none; }`\n",
       "     but bootstrap.min.css set `[hidden] { display: none !important; }`\n",
       "     so we also need the `!important` here to be able to override the\n",
       "     default hidden behavior on the sphinx rendered scikit-learn.org.\n",
       "     See: https://github.com/scikit-learn/scikit-learn/issues/21755 */\n",
       "  display: inline-block !important;\n",
       "  position: relative;\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-text-repr-fallback {\n",
       "  display: none;\n",
       "}\n",
       "\n",
       "div.sk-parallel-item,\n",
       "div.sk-serial,\n",
       "div.sk-item {\n",
       "  /* draw centered vertical line to link estimators */\n",
       "  background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));\n",
       "  background-size: 2px 100%;\n",
       "  background-repeat: no-repeat;\n",
       "  background-position: center center;\n",
       "}\n",
       "\n",
       "/* Parallel-specific style estimator block */\n",
       "\n",
       "#sk-container-id-2 div.sk-parallel-item::after {\n",
       "  content: \"\";\n",
       "  width: 100%;\n",
       "  border-bottom: 2px solid var(--sklearn-color-text-on-default-background);\n",
       "  flex-grow: 1;\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-parallel {\n",
       "  display: flex;\n",
       "  align-items: stretch;\n",
       "  justify-content: center;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "  position: relative;\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-parallel-item {\n",
       "  display: flex;\n",
       "  flex-direction: column;\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-parallel-item:first-child::after {\n",
       "  align-self: flex-end;\n",
       "  width: 50%;\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-parallel-item:last-child::after {\n",
       "  align-self: flex-start;\n",
       "  width: 50%;\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-parallel-item:only-child::after {\n",
       "  width: 0;\n",
       "}\n",
       "\n",
       "/* Serial-specific style estimator block */\n",
       "\n",
       "#sk-container-id-2 div.sk-serial {\n",
       "  display: flex;\n",
       "  flex-direction: column;\n",
       "  align-items: center;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "  padding-right: 1em;\n",
       "  padding-left: 1em;\n",
       "}\n",
       "\n",
       "\n",
       "/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is\n",
       "clickable and can be expanded/collapsed.\n",
       "- Pipeline and ColumnTransformer use this feature and define the default style\n",
       "- Estimators will overwrite some part of the style using the `sk-estimator` class\n",
       "*/\n",
       "\n",
       "/* Pipeline and ColumnTransformer style (default) */\n",
       "\n",
       "#sk-container-id-2 div.sk-toggleable {\n",
       "  /* Default theme specific background. It is overwritten whether we have a\n",
       "  specific estimator or a Pipeline/ColumnTransformer */\n",
       "  background-color: var(--sklearn-color-background);\n",
       "}\n",
       "\n",
       "/* Toggleable label */\n",
       "#sk-container-id-2 label.sk-toggleable__label {\n",
       "  cursor: pointer;\n",
       "  display: block;\n",
       "  width: 100%;\n",
       "  margin-bottom: 0;\n",
       "  padding: 0.5em;\n",
       "  box-sizing: border-box;\n",
       "  text-align: center;\n",
       "}\n",
       "\n",
       "#sk-container-id-2 label.sk-toggleable__label-arrow:before {\n",
       "  /* Arrow on the left of the label */\n",
       "  content: \"▸\";\n",
       "  float: left;\n",
       "  margin-right: 0.25em;\n",
       "  color: var(--sklearn-color-icon);\n",
       "}\n",
       "\n",
       "#sk-container-id-2 label.sk-toggleable__label-arrow:hover:before {\n",
       "  color: var(--sklearn-color-text);\n",
       "}\n",
       "\n",
       "/* Toggleable content - dropdown */\n",
       "\n",
       "#sk-container-id-2 div.sk-toggleable__content {\n",
       "  max-height: 0;\n",
       "  max-width: 0;\n",
       "  overflow: hidden;\n",
       "  text-align: left;\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-toggleable__content.fitted {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-toggleable__content pre {\n",
       "  margin: 0.2em;\n",
       "  border-radius: 0.25em;\n",
       "  color: var(--sklearn-color-text);\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-toggleable__content.fitted pre {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-2 input.sk-toggleable__control:checked~div.sk-toggleable__content {\n",
       "  /* Expand drop-down */\n",
       "  max-height: 200px;\n",
       "  max-width: 100%;\n",
       "  overflow: auto;\n",
       "}\n",
       "\n",
       "#sk-container-id-2 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {\n",
       "  content: \"▾\";\n",
       "}\n",
       "\n",
       "/* Pipeline/ColumnTransformer-specific style */\n",
       "\n",
       "#sk-container-id-2 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
       "  color: var(--sklearn-color-text);\n",
       "  background-color: var(--sklearn-color-unfitted-level-2);\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
       "  background-color: var(--sklearn-color-fitted-level-2);\n",
       "}\n",
       "\n",
       "/* Estimator-specific style */\n",
       "\n",
       "/* Colorize estimator box */\n",
       "#sk-container-id-2 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-2);\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-2);\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-label label.sk-toggleable__label,\n",
       "#sk-container-id-2 div.sk-label label {\n",
       "  /* The background is the default theme color */\n",
       "  color: var(--sklearn-color-text-on-default-background);\n",
       "}\n",
       "\n",
       "/* On hover, darken the color of the background */\n",
       "#sk-container-id-2 div.sk-label:hover label.sk-toggleable__label {\n",
       "  color: var(--sklearn-color-text);\n",
       "  background-color: var(--sklearn-color-unfitted-level-2);\n",
       "}\n",
       "\n",
       "/* Label box, darken color on hover, fitted */\n",
       "#sk-container-id-2 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {\n",
       "  color: var(--sklearn-color-text);\n",
       "  background-color: var(--sklearn-color-fitted-level-2);\n",
       "}\n",
       "\n",
       "/* Estimator label */\n",
       "\n",
       "#sk-container-id-2 div.sk-label label {\n",
       "  font-family: monospace;\n",
       "  font-weight: bold;\n",
       "  display: inline-block;\n",
       "  line-height: 1.2em;\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-label-container {\n",
       "  text-align: center;\n",
       "}\n",
       "\n",
       "/* Estimator-specific */\n",
       "#sk-container-id-2 div.sk-estimator {\n",
       "  font-family: monospace;\n",
       "  border: 1px dotted var(--sklearn-color-border-box);\n",
       "  border-radius: 0.25em;\n",
       "  box-sizing: border-box;\n",
       "  margin-bottom: 0.5em;\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-estimator.fitted {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-0);\n",
       "}\n",
       "\n",
       "/* on hover */\n",
       "#sk-container-id-2 div.sk-estimator:hover {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-2);\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-estimator.fitted:hover {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-2);\n",
       "}\n",
       "\n",
       "/* Specification for estimator info (e.g. \"i\" and \"?\") */\n",
       "\n",
       "/* Common style for \"i\" and \"?\" */\n",
       "\n",
       ".sk-estimator-doc-link,\n",
       "a:link.sk-estimator-doc-link,\n",
       "a:visited.sk-estimator-doc-link {\n",
       "  float: right;\n",
       "  font-size: smaller;\n",
       "  line-height: 1em;\n",
       "  font-family: monospace;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "  border-radius: 1em;\n",
       "  height: 1em;\n",
       "  width: 1em;\n",
       "  text-decoration: none !important;\n",
       "  margin-left: 1ex;\n",
       "  /* unfitted */\n",
       "  border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
       "  color: var(--sklearn-color-unfitted-level-1);\n",
       "}\n",
       "\n",
       ".sk-estimator-doc-link.fitted,\n",
       "a:link.sk-estimator-doc-link.fitted,\n",
       "a:visited.sk-estimator-doc-link.fitted {\n",
       "  /* fitted */\n",
       "  border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
       "  color: var(--sklearn-color-fitted-level-1);\n",
       "}\n",
       "\n",
       "/* On hover */\n",
       "div.sk-estimator:hover .sk-estimator-doc-link:hover,\n",
       ".sk-estimator-doc-link:hover,\n",
       "div.sk-label-container:hover .sk-estimator-doc-link:hover,\n",
       ".sk-estimator-doc-link:hover {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-3);\n",
       "  color: var(--sklearn-color-background);\n",
       "  text-decoration: none;\n",
       "}\n",
       "\n",
       "div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,\n",
       ".sk-estimator-doc-link.fitted:hover,\n",
       "div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,\n",
       ".sk-estimator-doc-link.fitted:hover {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-3);\n",
       "  color: var(--sklearn-color-background);\n",
       "  text-decoration: none;\n",
       "}\n",
       "\n",
       "/* Span, style for the box shown on hovering the info icon */\n",
       ".sk-estimator-doc-link span {\n",
       "  display: none;\n",
       "  z-index: 9999;\n",
       "  position: relative;\n",
       "  font-weight: normal;\n",
       "  right: .2ex;\n",
       "  padding: .5ex;\n",
       "  margin: .5ex;\n",
       "  width: min-content;\n",
       "  min-width: 20ex;\n",
       "  max-width: 50ex;\n",
       "  color: var(--sklearn-color-text);\n",
       "  box-shadow: 2pt 2pt 4pt #999;\n",
       "  /* unfitted */\n",
       "  background: var(--sklearn-color-unfitted-level-0);\n",
       "  border: .5pt solid var(--sklearn-color-unfitted-level-3);\n",
       "}\n",
       "\n",
       ".sk-estimator-doc-link.fitted span {\n",
       "  /* fitted */\n",
       "  background: var(--sklearn-color-fitted-level-0);\n",
       "  border: var(--sklearn-color-fitted-level-3);\n",
       "}\n",
       "\n",
       ".sk-estimator-doc-link:hover span {\n",
       "  display: block;\n",
       "}\n",
       "\n",
       "/* \"?\"-specific style due to the `<a>` HTML tag */\n",
       "\n",
       "#sk-container-id-2 a.estimator_doc_link {\n",
       "  float: right;\n",
       "  font-size: 1rem;\n",
       "  line-height: 1em;\n",
       "  font-family: monospace;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "  border-radius: 1rem;\n",
       "  height: 1rem;\n",
       "  width: 1rem;\n",
       "  text-decoration: none;\n",
       "  /* unfitted */\n",
       "  color: var(--sklearn-color-unfitted-level-1);\n",
       "  border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
       "}\n",
       "\n",
       "#sk-container-id-2 a.estimator_doc_link.fitted {\n",
       "  /* fitted */\n",
       "  border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
       "  color: var(--sklearn-color-fitted-level-1);\n",
       "}\n",
       "\n",
       "/* On hover */\n",
       "#sk-container-id-2 a.estimator_doc_link:hover {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-3);\n",
       "  color: var(--sklearn-color-background);\n",
       "  text-decoration: none;\n",
       "}\n",
       "\n",
       "#sk-container-id-2 a.estimator_doc_link.fitted:hover {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-3);\n",
       "}\n",
       "</style><div id=\"sk-container-id-2\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>KMeans(n_clusters=2)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-2\" type=\"checkbox\" checked><label for=\"sk-estimator-id-2\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\">&nbsp;&nbsp;KMeans<a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.5/modules/generated/sklearn.cluster.KMeans.html\">?<span>Documentation for KMeans</span></a><span class=\"sk-estimator-doc-link fitted\">i<span>Fitted</span></span></label><div class=\"sk-toggleable__content fitted\"><pre>KMeans(n_clusters=2)</pre></div> </div></div></div></div>"
      ],
      "text/plain": [
       "KMeans(n_clusters=2)"
      ]
     },
     "execution_count": 93,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fit the 2-cluster KMeans model on the text embeddings (one row per input text)\n",
    "kmeans.fit(output_embeddings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "id": "6daf5bc6-c8bd-4a64-987f-efb722565ff3",
   "metadata": {},
   "outputs": [],
   "source": [
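    "# labels_[i] is the cluster index assigned to the i-th input text.\n",
    "# The index values themselves are arbitrary; only the grouping is meaningful.\n",
    "label = kmeans.labels_"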
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 96,
   "id": "1e3eae8e-d97e-49b2-b9b0-411905816860",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cls(苹果) = 1\n",
      "cls(菠萝) = 1\n",
      "cls(西瓜) = 1\n",
      "cls(斑马) = 0\n",
      "cls(大象) = 0\n",
      "cls(老鼠) = 0\n"
     ]
    }
   ],
   "source": [
    "# Print the cluster assigned to each text\n",
    "for text, cluster_id in zip(texts, label):\n",
    "    print(f\"cls({text}) = {cluster_id}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm",
   "language": "python",
   "name": "llm"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
